#!/usr/bin/env python3
"""
Entry point for the mass-gap simulation integration (real implementation).

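Loads the YAML config passed via --config (typically the shared
configs/default.yaml), iterates over every (b, k, n0, L) grid point,
calls run_mass_gap from orig/mass_gap_core.py to get all replicates,
and writes a per-replicate CSV (results/mass_gap_full.csv next to this
script) plus a summary CSV (mass_gap_summary.csv in --output-dir).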

This version is robust to different config layouts:
- Accepts flip-count and kernel **templates** (e.g., ".../L{L}/flip_counts_L{L}.npy")
- Accepts per-gauge kernel path dicts (e.g., kernel_paths.SU2)
- Falls back across several common locations without requiring YAML changes
"""

import argparse
import pathlib
import sys
import numpy as np
from typing import Any, Dict, Optional

# Prepend the repository root (parents[1] of this file, i.e. the directory
# above the one holding this script) to the import path, so the
# project-local `sim_utils` / `orig` imports below resolve when this file is
# executed directly as a script. Skipped if the entry is already present.
repo_root = pathlib.Path(__file__).resolve().parents[1]
_repo_root_entry = str(repo_root)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

from sim_utils import load_config, sweep_iter, seed_all, save_csv
from orig.mass_gap_core import run_mass_gap


def deep_get(d: Dict[str, Any], dotted_key: str, default: Any = None) -> Any:
    """Get a nested key using dot-notation, e.g. 'a.b.c'."""
    cur = d
    for part in dotted_key.split("."):
        if not isinstance(cur, dict) or part not in cur:
            return default
        cur = cur[part]
    return cur


def resolve_flip_template(cfg: Dict[str, Any]) -> str:
    """
    Locate the flip-counts path (or path template) in the config.

    A fixed list of dotted keys is probed in priority order: top-level keys,
    then a nested ``paths`` section, then a few simulation-specific
    sections. The first value that is a non-blank string is returned, so no
    YAML edits are needed to support any of these layouts.

    Raises:
        KeyError: if none of the probed keys holds a non-blank string; the
            message lists every key that was tried.
    """
    search_keys = (
        "flip_counts_path",                         # direct path
        "flip_counts_path_template",                # direct template
        "paths.flip_counts_path",                   # nested 'paths'
        "vol4_loop_fluctuation_sim.flip_counts_path_template",
        "vol4_loop_fluctuation_sim.flip_counts_template",
        "vol4_loop_fluctuation_sim.paths.flip_counts_path",
        "adjoint_volume.flip_counts_path_template",
    )
    # Lazily probe keys in order; stop at the first usable string.
    probed = (deep_get(cfg, key) for key in search_keys)
    template = next(
        (value for value in probed if isinstance(value, str) and value.strip()),
        None,
    )
    if template is None:
        raise KeyError(
            "No flip-counts template found. Looked for: " + ", ".join(search_keys)
        )
    return template


def resolve_kernel_spec(cfg: Dict[str, Any], gauge: str) -> str:
    """
    Locate the kernel path (or path template) in the config for *gauge*.

    Keys are probed in priority order: a single top-level path/template
    first, then per-gauge entries (``...kernel_paths.<gauge>``) under a few
    known sections. The first value that is a non-blank string wins.

    Raises:
        KeyError: if none of the probed keys holds a non-blank string; the
            message names the gauge and lists every key that was tried.
    """
    search_keys = (
        "kernel_path",                               # direct path
        "kernel_template",                           # direct template
        f"kernel_paths.{gauge}",                     # per-gauge dict
        f"vol4_wilson_loop_adjoint_volume_sweep.kernel_paths.{gauge}",
        f"vol4_discrete_gauge_wilson_loop.kernel_paths.{gauge}",
        f"paths.kernel_paths.{gauge}",
    )
    # Lazily probe keys in order; stop at the first usable string.
    probed = (deep_get(cfg, key) for key in search_keys)
    spec = next(
        (value for value in probed if isinstance(value, str) and value.strip()),
        None,
    )
    if spec is None:
        raise KeyError(
            f"No kernel path/template found for gauge={gauge}. Looked for: "
            + ", ".join(search_keys)
        )
    return spec


def format_template(spec: str, **kwargs: Any) -> str:
    """
    Fill ``{placeholder}`` fields in *spec* from *kwargs*.

    A string containing no ``{`` is returned untouched without going through
    ``str.format``. This means a plain path that happens to contain a stray
    ``}`` passes through instead of raising. Keyword arguments the template
    does not reference are ignored. A placeholder with no matching keyword
    raises ``KeyError`` (from ``str.format``).
    """
    if "{" not in spec:
        return spec
    return spec.format(**kwargs)


def _summary_stats(values: np.ndarray) -> "tuple[float, float]":
    """
    Return ``(mean, standard error of the mean)`` for a 1-D replicate array.

    The standard error is the sample standard deviation (``ddof=1``) divided
    by ``sqrt(n)``. It is undefined for fewer than two replicates, so NaN is
    returned in that case rather than a misleading zero error bar. The mean
    is NaN when there are no replicates at all.
    """
    n = values.size
    if n == 0:
        return float("nan"), float("nan")
    mean = float(values.mean())
    if n < 2:
        return mean, float("nan")
    return mean, float(values.std(ddof=1) / np.sqrt(n))


def main() -> None:
    """
    Parse CLI arguments, sweep the configured (b, k, n0, L) grid, run the
    mass-gap ensemble at each point, and write the results to CSV.

    Two CSVs are produced:
      * ``<output-dir>/mass_gap_summary.csv``: one row per grid point with
        the replicate mean (``m_eff``) and its standard error (``m_err``;
        NaN when fewer than two replicates are available);
      * ``<this module>/results/mass_gap_full.csv``: one row per replicate.

    Raises:
        KeyError: if no flip-counts or kernel path can be located in the
            config, or a path template references an unknown placeholder.
    """
    parser = argparse.ArgumentParser(description="Run the mass gap simulation")
    parser.add_argument("--config", "-c", required=True,
                        help="Path to configs/default.yaml")
    parser.add_argument("--output-dir", "-o", required=True,
                        help="Directory to write summary CSV")
    args = parser.parse_args()

    # Load master YAML
    cfg: Dict[str, Any] = load_config(args.config)

    out_dir = pathlib.Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Prepare full-results CSV folder inside this module
    module_root = pathlib.Path(__file__).resolve().parent
    full_results_dir = module_root / "results"
    full_results_dir.mkdir(parents=True, exist_ok=True)
    full_csv = full_results_dir / "mass_gap_full.csv"

    # Summary CSV (in requested output dir)
    summary_csv = out_dir / "mass_gap_summary.csv"

    # NOTE(review): presumably save_csv appends one row per call; if so,
    # re-running this script accumulates rows from earlier runs in both
    # CSVs. Confirm, and clear the files up front if that is not intended.

    # Mass-gap meta-config (an explicit `mass_gap: null` is treated as empty)
    mg_cfg: Dict[str, Any] = cfg.get("mass_gap", {}) or {}
    ensemble_size: int = int(mg_cfg.get("ensemble_size", 10))
    gauge: str = str(mg_cfg.get("gauge", "SU2"))

    # Resolve templates/specs from various possible locations
    flip_template: str = resolve_flip_template(cfg)
    kernel_spec: str = resolve_kernel_spec(cfg, gauge=gauge)

    print(f"[mass-gap] Using gauge={gauge}, ensemble_size={ensemble_size}")
    print(f"[mass-gap] flip template: {flip_template}")
    print(f"[mass-gap] kernel spec:   {kernel_spec}")

    # Sweep over all grid points
    for b, k, n0, L in sweep_iter(cfg):
        seed_all(b, k, n0, L)

        # Resolve per-L paths
        flip_counts_path = format_template(
            flip_template, L=L, gauge=gauge, b=b, k=k, n0=n0
        )
        kernel_path = format_template(
            kernel_spec, L=L, gauge=gauge, b=b, k=k, n0=n0
        )

        print(f"[mass-gap] L={L} → flip_counts={flip_counts_path} | kernel={kernel_path}")

        # run_mass_gap is expected to return one float per replicate.
        # Materialise it: the values are consumed twice below (stats + full
        # CSV), which would silently lose rows if an iterator were returned.
        replicates = list(
            run_mass_gap(
                b=b,
                k=k,
                n0=n0,
                L=L,
                ensemble_size=ensemble_size,
                flip_counts_path=flip_counts_path,
                kernel_path=kernel_path,
            )
        )
        arr = np.array(replicates, dtype=float)

        # Summary stats (m_err is NaN, not 0, when it is undefined)
        m_eff, m_err = _summary_stats(arr)

        # Write summary row
        save_csv(
            summary_csv,
            {
                "b": b,
                "k": k,
                "n0": n0,
                "L": L,
                "m_eff": m_eff,
                "m_err": m_err,
                "gauge": gauge,
            },
        )

        # Write each replicate to full CSV
        for m in replicates:
            save_csv(
                full_csv,
                {
                    "b": b,
                    "k": k,
                    "n0": n0,
                    "L": L,
                    "gauge": gauge,
                    "mass_gap": float(m),
                },
            )

    print(f"✅ mass-gap summary → {summary_csv}")
    print(f"✅ mass-gap full results → {full_csv}")


if __name__ == "__main__":
    main()
